{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Effect of the sample size in cross-validation\n",
    "\n",
    "In the previous notebook, we presented the general cross-validation framework\n",
    "and how to assess if a predictive model is underfitting, overfitting, or\n",
    "generalizing. Besides these aspects, it is also important to understand how\n",
    "the different errors are influenced by the number of samples available.\n",
    "\n",
    "In this notebook, we will show this aspect by looking at the variability of\n",
    "the different errors.\n",
    "\n",
    "Let's first load the data and create the same model as in the previous\n",
    "notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import fetch_california_housing\n",
    "\n",
    "housing = fetch_california_housing(as_frame=True)\n",
    "data, target = housing.data, housing.target\n",
    "target *= 100  # rescale the target in k$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"admonition note alert alert-info\">\n",
    "<p class=\"first admonition-title\" style=\"font-weight: bold;\">Note</p>\n",
    "<p class=\"last\">If you want a deeper overview regarding this dataset, you can refer to the\n",
    "Appendix - Datasets description section at the end of this MOOC.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.tree import DecisionTreeRegressor\n",
    "\n",
    "regressor = DecisionTreeRegressor()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Learning curve\n",
    "\n",
    "To understand the impact of the number of samples available for training on\n",
    "the generalization performance of a predictive model, we can artificially\n",
    "reduce the number of samples used to train the predictive model and check the\n",
    "resulting training and testing errors.\n",
    "\n",
    "Therefore, we can vary the number of samples in the training set and repeat\n",
    "the experiment. The training and testing scores can be plotted similarly to\n",
    "the validation curve, but instead of varying a hyperparameter, we vary the\n",
    "number of training samples. This curve is called the **learning curve**.\n",
    "\n",
    "It gives information regarding the benefit of adding new training samples\n",
    "to improve a model's generalization performance.\n",
    "\n",
    "Let's compute the learning curve for a decision tree and vary the\n",
    "proportion of the training set from 10% to 100%."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "train_sizes = np.linspace(0.1, 1.0, num=5, endpoint=True)\n",
    "train_sizes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use a `ShuffleSplit` cross-validation to assess our predictive model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import ShuffleSplit\n",
    "\n",
    "cv = ShuffleSplit(n_splits=30, test_size=0.2)"
   ]
  },
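  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The fractions stored in `train_sizes` are relative to the training portion of\n",
    "each split, not to the full dataset. As a minimal sketch, we can preview how\n",
    "many samples each point of the learning curve would use. It assumes that a\n",
    "fraction is converted by multiplying it by the size of the training portion\n",
    "and truncating to an integer, which is our reading of scikit-learn's\n",
    "behavior. The names `n_train_max` and `train_sizes_abs` are introduced here\n",
    "for illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size of the training portion of a single split (80% of the samples here).\n",
    "train_indices, test_indices = next(cv.split(data))\n",
    "n_train_max = len(train_indices)\n",
    "\n",
    "# Assumption: fractions are scaled by the training portion and truncated.\n",
    "train_sizes_abs = (train_sizes * n_train_max).astype(int)\n",
    "\n",
    "print(f\"Samples in the full dataset: {len(data)}\")\n",
    "print(f\"Largest training set: {n_train_max}\")\n",
    "print(f\"Training set sizes: {train_sizes_abs}\")"
   ]
  },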
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we are all set to carry out the experiment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import LearningCurveDisplay\n",
    "\n",
    "display = LearningCurveDisplay.from_estimator(\n",
    "    regressor,\n",
    "    data,\n",
    "    target,\n",
    "    train_sizes=train_sizes,\n",
    "    cv=cv,\n",
    "    score_type=\"both\",  # both train and test errors\n",
    "    scoring=\"neg_mean_absolute_error\",\n",
    "    negate_score=True,  # to use when metric starts with \"neg_\"\n",
    "    score_name=\"Mean absolute error (k$)\",\n",
    "    std_display_style=\"errorbar\",\n",
    "    n_jobs=2,\n",
    ")\n",
    "_ = display.ax_.set(xscale=\"log\", title=\"Learning curve for decision tree\")"
   ]
  },
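  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The plot summarizes the scores, but it can also be useful to inspect the\n",
    "underlying numbers. The following is a minimal sketch that calls the\n",
    "`learning_curve` function from `sklearn.model_selection` with the same\n",
    "settings and tabulates the mean and standard deviation of the errors for each\n",
    "training set size. The names `train_errors`, `test_errors` and `summary` are\n",
    "ours. Since `cv` has no fixed `random_state`, the splits are drawn again and\n",
    "the values may differ slightly from the ones shown in the plot. The models\n",
    "are also refit, so this cell takes about as long to run as the previous one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.model_selection import learning_curve\n",
    "\n",
    "train_sizes_abs, train_scores, test_scores = learning_curve(\n",
    "    regressor,\n",
    "    data,\n",
    "    target,\n",
    "    train_sizes=train_sizes,\n",
    "    cv=cv,\n",
    "    scoring=\"neg_mean_absolute_error\",\n",
    "    n_jobs=2,\n",
    ")\n",
    "\n",
    "# Scores are negated errors of shape (n_train_sizes, n_splits): flip the sign.\n",
    "train_errors, test_errors = -train_scores, -test_scores\n",
    "\n",
    "# One row per training set size, aggregated over the cross-validation splits.\n",
    "summary = pd.DataFrame(\n",
    "    {\n",
    "        \"Mean train error (k$)\": train_errors.mean(axis=1),\n",
    "        \"Mean test error (k$)\": test_errors.mean(axis=1),\n",
    "        \"Std test error (k$)\": test_errors.std(axis=1),\n",
    "    },\n",
    "    index=pd.Index(train_sizes_abs, name=\"Number of training samples\"),\n",
    ")\n",
    "summary"
   ]
  },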
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Looking at the training error alone, we see that we get an error of 0 k$. It\n",
    "means that the trained model (i.e. the decision tree) memorizes the training\n",
    "set, which is a clear sign of overfitting.\n",
    "\n",
    "Looking at the testing error alone, we observe that the more samples are\n",
    "added into the training set, the lower the testing error becomes. We are also\n",
    "looking for a plateau of the testing error, beyond which adding samples no\n",
    "longer brings any benefit. When no plateau is visible, the slope of the curve\n",
    "helps assess the potential gain of adding more samples to the training set.\n",
    "\n",
    "If the testing error plateaus despite adding more training samples, it's\n",
    "possible that the model has achieved its optimal performance. In this case,\n",
    "using a more expressive model might help reduce the error further. Otherwise,\n",
    "the error may have reached the Bayes error rate, the theoretical minimum error\n",
    "due to inherent uncertainty not resolved by the available data. This minimum error is\n",
    "non-zero whenever some of the variation of the target variable `y` depends on\n",
    "external factors not fully observed in the features available in `X`, which is\n",
    "almost always the case in practice.\n",
    "\n",
    "## Summary\n",
    "\n",
    "In the notebook, we learnt:\n",
    "\n",
    "* the influence of the number of samples in a dataset, especially on the\n",
    "  variability of the errors reported when running the cross-validation;\n",
    "* about the learning curve, which is a visual representation of the capacity\n",
    "  of a model to improve by adding new samples."
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "main_language": "python"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}